from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
    open_lib,
    to_tensor,
    DeviceEnum,
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    create_handle,
    destroy_handle,
    check_error,
    rearrange_tensor,
    create_workspace,
    U64,
)

from operatorspy.tests.test_utils import get_args
import torch


class RandomSampleDescriptor(Structure):
    """ctypes view of the RandomSample operator descriptor.

    Only a single ``device`` field (a 32-bit signed integer) is declared
    here.
    NOTE(review): presumably the native struct carries more state and is
    only ever handled through a pointer from Python -- confirm against the
    C header.
    """

    _fields_ = [
        ("device", c_int32),
    ]


# Pointer type used wherever the C API takes or returns a descriptor handle.
infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)


def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
    """Reference top-k / top-p sampling used to validate the operator.

    The ``topk`` largest entries of ``data`` are selected, a softmax with the
    given ``temperature`` is taken over all ``voc`` entries, the candidate
    list is cut once its cumulative probability reaches ``topp`` (never
    beyond ``topk`` candidates), and ``random_val`` scaled by the kept
    probability mass picks one candidate.

    Args:
        data: 1-D tensor with at least ``voc`` elements; it is not modified.
        random_val: Sampling value; values in ``[0, 1)`` always land inside
            the kept candidates.
        topp: Cumulative-probability cutoff.
        topk: Maximum number of candidates; must be at least 1.
        voc: Number of entries of ``data`` taking part in the selection.
        temperature: Divisor applied to the shifted values before softmax.
        torch_device: Unused; kept so existing callers keep working.

    Returns:
        A 0-dim int64 tensor holding the original index (into ``data``) of
        the sampled entry.
    """
    values = data.clone().detach()
    sorted_indices = torch.arange(voc)

    # Partial selection sort: after pass ``i`` position ``i`` holds the
    # i-th largest value and ``sorted_indices[i]`` its original index.
    # Elements are cloned before being written back because integer
    # indexing yields views that would otherwise alias the swapped slots.
    for i in range(topk):
        for j in range(i + 1, voc):
            if values[i] < values[j]:
                tmp_val = values[i].clone().detach()
                values[i] = values[j].clone().detach()
                values[j] = tmp_val

                tmp_ind = sorted_indices[i].clone().detach()
                sorted_indices[i] = sorted_indices[j].clone().detach()
                sorted_indices[j] = tmp_ind

    indices = sorted_indices[:topk]

    # ``values`` was reordered in place together with ``sorted_indices``, so
    # it must NOT be gathered through ``sorted_indices`` again: doing so
    # scrambles the sorted order and the cumulative sums below would run
    # over arbitrary probabilities instead of the top-k ones.
    global_max = values[0]
    shifted = (values - global_max) / temperature
    probs = torch.softmax(shifted.float(), dim=0)

    # Find how many leading candidates are needed to reach ``topp``.
    cumulative = 0
    for end in range(topk):
        cumulative += probs[end]
        if cumulative >= topp:
            break
    if end < topk - 1:
        end += 1
    else:
        end = topk

    # Probability mass of the kept candidates; the random value is rescaled
    # into it (without mutating the caller's ``random_val``).
    kept_mass = 0
    for i in range(end):
        kept_mass += probs[i]
    threshold = random_val * kept_mass

    running = 0
    for i in range(end):
        running += probs[i]
        if threshold < running:
            return indices[i]

    # Reached only when the threshold is not below the total kept mass
    # (e.g. random_val >= 1 or rounding); fall back to the last candidate
    # instead of implicitly returning None.
    return indices[end - 1]

def random_sample_0(data):
    """Greedy reference: return the index of the largest entry of ``data``."""
    return data.argmax()

def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_dtype=torch.float16):
    """Run one RandomSample case through the library and check the result.

    A shuffled ramp of ``voc`` values is built, the expected index is
    computed with the Python reference (``random_sample`` when both
    ``topp > 0`` and ``topk > 1``, otherwise ``random_sample_0``), and the
    library operator is invoked on the same data. The case passes when the
    library returns the same index or an index whose value equals the
    expected one.

    Args:
        lib: Loaded operator library exposing the ``infiniop*RandomSample*``
            functions.
        handle: Library handle passed to descriptor creation.
        torch_device: Device name for this run; ``'maca'`` tensors are
            placed on torch's ``'cuda'`` device.
        voc: Number of input values.
        random_val, topp, topk, temperature: Sampling parameters forwarded
            to both the reference and the library call.
        x_dtype: dtype of the input tensor.
    """
    print(
        f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}"
    )

    # 'maca' is addressed through the 'cuda' device string on the torch side.
    tensor_device = 'cuda' if torch_device == 'maca' else torch_device

    ramp = torch.arange(voc).float() * 0.0001
    shuffle = torch.randperm(voc)
    data = ramp[shuffle].to(x_dtype).to(tensor_device)

    # Expected answer from the Python reference implementation.
    if topp > 0 and topk > 1:
        ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
    else:
        ans = random_sample_0(data)

    indices = torch.zeros([1], dtype=torch.int64).to(tensor_device)

    x_tensor = to_tensor(data, lib)
    indices_tensor = to_tensor(indices, lib)
    indices_tensor.descriptor.contents.dt = U64  # treat int64 as uint64

    descriptor = infiniopRandomSampleDescriptor_t()
    create_status = lib.infiniopCreateRandomSampleDescriptor(
        handle, ctypes.byref(descriptor), indices_tensor.descriptor, x_tensor.descriptor
    )
    check_error(create_status)

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in (x_tensor, indices_tensor):
        tensor.descriptor.contents.invalidate()

    workspace_size = c_uint64(0)
    size_status = lib.infiniopGetRandomSampleWorkspaceSize(
        descriptor, ctypes.byref(workspace_size)
    )
    check_error(size_status)

    workspace = create_workspace(workspace_size.value, torch_device)
    workspace_ptr = None if workspace is None else workspace.data_ptr()
    run_status = lib.infiniopRandomSample(
        descriptor,
        workspace_ptr,
        workspace_size.value,
        indices_tensor.data,
        x_tensor.data,
        random_val,
        topp,
        topk,
        temperature,
        None,
    )
    check_error(run_status)

    if torch_device == "npu":
        torch.npu.synchronize()

    # Accept either the exact index or a different index holding an equal value.
    sampled = indices[0]
    assert sampled.type(ans.dtype) == ans or data[ans] == data[sampled]
    check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))

def test_cpu(lib, test_cases):
    """Run every RandomSample test case on the CPU device."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CPU)
    for case in test_cases:
        voc, random_val, topp, topk, temperature = case
        test(lib, handle, "cpu", voc, random_val, topp, topk, temperature)
    destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
    """Run every RandomSample test case on the CUDA device."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CUDA)
    for case in test_cases:
        voc, random_val, topp, topk, temperature = case
        test(lib, handle, "cuda", voc, random_val, topp, topk, temperature)
    destroy_handle(lib, handle)


def test_bang(lib, test_cases):
    """Run every RandomSample test case on the BANG device ("mlu" in torch)."""
    # Imported for its side effect of making the "mlu" device usable in torch.
    import torch_mlu

    handle = create_handle(lib, DeviceEnum.DEVICE_BANG)
    for case in test_cases:
        voc, random_val, topp, topk, temperature = case
        test(lib, handle, "mlu", voc, random_val, topp, topk, temperature)
    destroy_handle(lib, handle)


def test_ascend(lib, test_cases):
    """Run every RandomSample test case on the Ascend device ("npu" in torch)."""
    # Imported for its side effect of making the "npu" device usable in torch.
    import torch_npu

    handle = create_handle(lib, DeviceEnum.DEVICE_ASCEND)
    for case in test_cases:
        voc, random_val, topp, topk, temperature = case
        test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
    destroy_handle(lib, handle)
    
def test_maca(lib, test_cases):
    """Run every RandomSample test case on the MACA device."""
    handle = create_handle(lib, DeviceEnum.DEVICE_MACA)
    for case in test_cases:
        voc, random_val, topp, topk, temperature = case
        test(lib, handle, "maca", voc, random_val, topp, topk, temperature)
    destroy_handle(lib, handle)
    

def test_musa(lib, test_cases):
    """Run every RandomSample test case on the MUSA device."""
    # Imported for its side effect of making the "musa" device usable in torch.
    import torch_musa

    handle = create_handle(lib, DeviceEnum.DEVICE_MUSA)
    for case in test_cases:
        voc, random_val, topp, topk, temperature = case
        test(lib, handle, "musa", voc, random_val, topp, topk, temperature)
    destroy_handle(lib, handle)

if __name__ == "__main__":
    # Script entry point: declare the C signatures of the RandomSample API,
    # then run the test cases on every device selected on the command line
    # (CPU when none is selected).
    test_cases = [
        # voc, random_val, topp, topk, temperature
        (512, 0.8, 0.8, 3, 0.5),
        (4096, 0.05, 0.9, 5, 1.0),
        (16384, 0.15, 0.85, 10, 2.0),
        (512, 0.08, 0, 3, 0.5),
        (4096, 0.5, 0.9, 1, 1.0),
        (16384, 0.15, 0, 1, 2.0),
        (16384, 0.15, 0, 1, 2.0),
        (32000, 0.08, 0.8, 50, 1.0),
        (32000, 0.08, 1.0, 25, 1.0),
        # (119696, 0.01, 1.0, 100, 1.0),
    ]

    args = get_args()
    lib = open_lib()

    # The create call is made with four arguments (handle, descriptor
    # out-pointer, indices descriptor, input descriptor), so all four are
    # declared; leaving the last one out makes ctypes pass it with default,
    # unchecked conversion.
    lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
    lib.infiniopCreateRandomSampleDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRandomSampleDescriptor_t),
        infiniopTensorDescriptor_t,  # result (indices) tensor descriptor
        infiniopTensorDescriptor_t,  # input tensor descriptor
    ]
    lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
    lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
        infiniopRandomSampleDescriptor_t,
        POINTER(c_uint64),
    ]
    lib.infiniopRandomSample.restype = c_int32
    lib.infiniopRandomSample.argtypes = [
        infiniopRandomSampleDescriptor_t,
        c_void_p,  # workspace pointer (may be None)
        c_uint64,  # workspace size
        c_uint64,  # result data address
        c_void_p,  # input data address
        c_float,   # random_val
        c_float,   # topp
        c_int32,   # topk
        c_float,   # temperature
        c_void_p,  # last argument; the test always passes None
    ]
    lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
    lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
        infiniopRandomSampleDescriptor_t,
    ]

    if args.cpu:
        test_cpu(lib, test_cases)
    if args.cuda:
        test_cuda(lib, test_cases)
    if args.bang:
        test_bang(lib, test_cases)
    if args.ascend:
        test_ascend(lib, test_cases)
    if args.maca:
        test_maca(lib, test_cases)
    if args.musa:
        test_musa(lib, test_cases)
    # No device flag given: default to the CPU run.
    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
        test_cpu(lib, test_cases)
    print("\033[92mTest passed!\033[0m")
